from dataset import DatasetLoader
import torch.utils.data as Data

class DatasetProcess:
    """Assemble the training and validation DataLoaders for a run.

    Two ``DatasetLoader`` instances are created, one per directory. Each is
    wrapped in a ``torch.utils.data.DataLoader`` that uses the same batch
    size and worker count. Only the training loader shuffles its samples.
    """

    def __init__(self, train_dir, val_dir, batch_size, num_workers):
        """Create both datasets and their DataLoaders.

        Args:
            train_dir: Directory passed as ``root_dir`` to the training
                ``DatasetLoader``.
            val_dir: Directory passed as ``root_dir`` to the validation
                ``DatasetLoader``.
            batch_size: Batch size used by both DataLoaders.
            num_workers: Worker-process count used by both DataLoaders.
        """
        # Build the training dataset first, then the validation one.
        train_set = DatasetLoader(root_dir=train_dir)
        val_set = DatasetLoader(root_dir=val_dir)

        # Settings shared by both loaders; only ``shuffle`` differs.
        shared_opts = dict(batch_size=batch_size, num_workers=num_workers)

        # Shuffle training samples each epoch; keep validation order fixed.
        self.dataTrain = Data.DataLoader(dataset=train_set, shuffle=True, **shared_opts)
        self.dataVal = Data.DataLoader(dataset=val_set, shuffle=False, **shared_opts)

    def get_data(self):
        """Return the loaders as a ``(train_loader, val_loader)`` tuple."""
        return self.dataTrain, self.dataVal


